/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file tir/ir/transform.cc
 * \brief TIR specific transformation passes.
 */
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/ffi/rvalue_ref.h>
#include <tvm/node/repr_printer.h>
#include <tvm/tir/transform.h>

namespace tvm {
namespace tir {
namespace transform {

// Register build pipeline related options.
//
// Each TVM_REGISTER_PASS_CONFIG_OPTION call below registers a "tir.*" pass
// configuration key together with the type its value is expected to have
// (Bool, Integer, or, for "tir.add_lower_pass", a nested array of ObjectRef).
// NOTE(review): the meaning of each individual flag is defined by the passes
// that read it, not in this file.
TVM_REGISTER_PASS_CONFIG_OPTION("tir.noalias", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.detect_global_barrier", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.instrument_bound_checkers", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_assert", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_vectorize", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.enable_buffer_level_predication", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_cse_tir", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.enable_debug", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.enable_equiv_terms_in_cse_tir", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_storage_rewrite", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.is_entry_func", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.add_lower_pass", ffi::Array<ffi::Array<ObjectRef>>);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.debug_keep_trivial_loop", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.use_async_copy", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.merge_static_smem", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.instrument_lwp", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.vtcm_capacity", Integer);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.ptx_ldg32", Bool);

/*!
 * \brief Function level pass that applies transformations to all
 *        TIR functions within the module.
 *
 * Only module entries that are tir::PrimFunc are visited; other kinds of
 * functions are left untouched.
 */
class PrimFuncPassNode : public PassNode {
 public:
  /*! \brief The pass meta data. */
  PassInfo pass_info;

  /*!
   * \brief The pass function called on each PrimFunc in the module.
   *
   * It receives the function, the enclosing module and the pass context, and
   * returns the transformed function. Returning a null PrimFunc causes the
   * function to be removed from the module.
   */
  std::function<PrimFunc(PrimFunc, IRModule, PassContext)> pass_func;

  /*!
   * \brief Register reflection metadata for this node type.
   *
   * Exposes `pass_info` as a read-only reflected field. `pass_func` is not
   * reflected.
   */
  static void RegisterReflection() {
    namespace refl = tvm::ffi::reflection;
    refl::ObjectDef<PrimFuncPassNode>().def_ro("pass_info", &PrimFuncPassNode::pass_info);
  }

  /*!
   * \brief Run a function pass on given pass context.
   *
   * \param mod The module that an optimization pass is applied on.
   * \param pass_ctx The context that an optimization pass executes on.
   *
   * \return Return the updated module.
   */
  IRModule operator()(IRModule mod, const PassContext& pass_ctx) const final;

  /*!
   * \brief Get the pass information/meta data.
   * \return The stored `pass_info`.
   */
  PassInfo Info() const override { return pass_info; }
  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.PrimFuncPass", PrimFuncPassNode, PassNode);
};

/*!
 * \brief Managed reference to PrimFuncPassNode.
 * \sa PrimFuncPassNode
 */
class PrimFuncPass : public Pass {
 public:
  /*!
   * \brief The constructor
   * \param pass_func The function applied to each PrimFunc in the module. It
   *        receives the function, the module and the pass context, and returns
   *        the transformed function (a null result removes the function).
   * \param pass_info The pass info.
   */
  TVM_DLL PrimFuncPass(std::function<PrimFunc(PrimFunc, IRModule, PassContext)> pass_func,
                       PassInfo pass_info);

  // NOTE(review): judging by the macro name, this reference type may also be
  // null (e.g. default-constructed) — confirm against the macro definition.
  TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(PrimFuncPass, Pass, PrimFuncPassNode);
};

PrimFuncPass::PrimFuncPass(std::function<PrimFunc(PrimFunc, IRModule, PassContext)> pass_func,
                           PassInfo pass_info) {
  // Allocate the backing node and transfer ownership of both arguments into it;
  // the two assignments are independent of each other.
  auto node = ffi::make_object<PrimFuncPassNode>();
  node->pass_info = std::move(pass_info);
  node->pass_func = std::move(pass_func);
  // Hand the node over to this reference.
  data_ = std::move(node);
}

/*!
 * \brief Perform Module -> Module optimizations at the PrimFunc level.
 *
 * Applies `pass_func` to every tir::PrimFunc in the module and stores the
 * result back under the same key. Entries that are not PrimFuncs are left
 * unchanged. If `pass_func` returns a null PrimFunc, the corresponding
 * GlobalVar is removed from the module once all functions have been visited.
 *
 * \param mod The module to transform; it is made unique with CopyOnWrite()
 *        before any entry is replaced.
 * \param pass_ctx The pass context, forwarded unchanged to `pass_func`.
 * \return The updated module.
 */
IRModule PrimFuncPassNode::operator()(IRModule mod, const PassContext& pass_ctx) const {
  ICHECK(mod.defined());
  // GlobalVars whose PrimFunc was replaced by null. Removal is deferred until
  // after the loop, presumably so the map is not restructured mid-iteration.
  std::vector<GlobalVar> deleted_list;

  IRModuleNode* mod_ptr = mod.CopyOnWrite();
  auto* func_dict = mod_ptr->functions.CopyOnWrite();
  // Directly loop over the underlying dict, replacing values in place.
  for (auto& kv : *func_dict) {
    // Only pick up tir::PrimFunc entries.
    if (auto opt_func = kv.second.as<PrimFunc>()) {
      // Clear the map slot so the map no longer references this function and
      // `opt_func` holds the reference owned here; then move it into
      // `pass_func`. Presumably this lets a pass that calls CopyOnWrite()
      // mutate the function in place instead of copying it.
      // NOTE: while `pass_func` runs, this key's slot in `mod` is empty, so
      // the pass cannot retrieve the original function through `mod`.
      kv.second.reset();
      PrimFunc func = *std::move(opt_func);
      func = pass_func(std::move(func), mod, pass_ctx);
      kv.second = Any(std::move(func));
      // A null result means the pass dropped this function.
      if (kv.second == nullptr) {
        deleted_list.push_back(Downcast<GlobalVar>(kv.first));
      }
    }
  }

  // Automatic removal of None.  This uses IRModuleNode::Remove
  // instead of manipulating func_dict directly, to ensure that both
  // the function map and the global_var_map_ are correctly updated.
  for (const auto& gv : deleted_list) {
    mod_ptr->Remove(gv);
  }
  return mod;
}

/*!
 * \brief Build a PrimFunc-level pass from a C++ callback and its metadata.
 * \param pass_func Callback applied to each PrimFunc in the module.
 * \param opt_level Optimization level recorded in the pass info.
 * \param name Name of the pass.
 * \param required Names of passes this pass depends on.
 * \param traceable Whether the pass is marked traceable in its pass info.
 * \return The constructed PrimFuncPass, returned as a generic Pass.
 */
Pass CreatePrimFuncPass(std::function<PrimFunc(PrimFunc, IRModule, PassContext)> pass_func,
                        int opt_level, ffi::String name, tvm::ffi::Array<ffi::String> required,
                        bool traceable) {
  return PrimFuncPass(std::move(pass_func), PassInfo(opt_level, name, required, traceable));
}

// Register PrimFuncPassNode's reflection metadata (its read-only `pass_info`
// field) from a static-init block.
TVM_FFI_STATIC_INIT_BLOCK() { PrimFuncPassNode::RegisterReflection(); }

TVM_FFI_STATIC_INIT_BLOCK() {
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef().def(
      "tir.transform.CreatePrimFuncPass",
      [](ffi::TypedFunction<PrimFunc(ffi::RValueRef<PrimFunc>, IRModule, PassContext)> pass_func,
         PassInfo pass_info) {
        auto wrapped_pass_func = [pass_func](PrimFunc func, IRModule mod, PassContext ctx) {
          return pass_func(ffi::RValueRef<PrimFunc>(std::move(func)), mod, ctx);
        };
        return PrimFuncPass(wrapped_pass_func, pass_info);
      });
}

// Textual representation: PrimFuncPass(<name>, opt_level=<opt_level>)
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
    .set_dispatch<PrimFuncPassNode>([](const ObjectRef& ref, ReprPrinter* p) {
      const PassInfo pass_info = static_cast<const PrimFuncPassNode*>(ref.get())->Info();
      auto& os = p->stream;
      os << "PrimFuncPass(" << pass_info->name;
      os << ", opt_level=" << pass_info->opt_level << ")";
    });

}  // namespace transform
}  // namespace tir
}  // namespace tvm
